{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hydra Configuration\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/ai4co/rl4co/blob/main/examples/advanced/1-hydra-config.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
    "\n",
    "[Hydra](https://hydra.cc/docs/intro/) makes it extremely convenient to configure projects with lots of parameter settings like the RL4CO library. \n",
    "\n",
    "While you don't need Hydra to use RL4CO, it is recommended to use it for your own projects to make it easier to manage the configuration of your experiments.\n",
    "\n",
    "Hydra uses config files in `.yaml` format for this. These files can be found in the [configs/](../../../configs/) folder, where the subfolders define configurations for specific parts of the framework which are then combined in the [main.yaml](../../../configs/main.yaml) configuration. In this tutorial we will have a look at how to use these different configuration files and how to add new parameters to the configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "ROOT_DIR = \"../../\" # relative to this file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# context initialization\n",
    "with initialize(version_base=None, config_path=ROOT_DIR+\"configs\"):\n",
    "    cfg = compose(config_name=\"main\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hydra stores the configurations in a dictionary like object called OmegaConf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "omegaconf.dictconfig.DictConfig"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The different subfolders in the configs folder are represented as distinct keys in the omegaconf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['mode',\n",
       " 'tags',\n",
       " 'train',\n",
       " 'test',\n",
       " 'compile',\n",
       " 'ckpt_path',\n",
       " 'seed',\n",
       " 'matmul_precision',\n",
       " 'model',\n",
       " 'callbacks',\n",
       " 'logger',\n",
       " 'trainer',\n",
       " 'paths',\n",
       " 'extras',\n",
       " 'env']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(cfg.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keys can be accessed using the dot notation (e.g. `cfg.model`) or via normal dictionaries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "print(cfg.model == cfg[\"model\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dot notation is however more convenient especially in nested structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "print(cfg.model._target_ == cfg[\"model\"][\"_target_\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, lets look at the model configuration (which corresponds the [model/default.yaml](../../../configs/model/default.yaml) configuration). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate_default_data: true\n",
      "metrics:\n",
      "  train:\n",
      "  - loss\n",
      "  - reward\n",
      "  val:\n",
      "  - reward\n",
      "  test:\n",
      "  - reward\n",
      "  log_on_step: true\n",
      "_target_: rl4co.models.AttentionModel\n",
      "baseline: rollout\n",
      "batch_size: 512\n",
      "val_batch_size: 1024\n",
      "test_batch_size: 1024\n",
      "train_data_size: 1280000\n",
      "val_data_size: 10000\n",
      "test_data_size: 10000\n",
      "optimizer_kwargs:\n",
      "  lr: 0.0001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(OmegaConf.to_yaml(cfg.model))"
   ]
  },
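  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two access styles shown earlier can also be tried on a small hand-built config, without the RL4CO config folder. The following is a minimal sketch, assuming OmegaConf 2.1 or newer; the dictionary is a hypothetical stand-in whose values are copied from the model config printed in this notebook, and `not_a_real_key` is made up to show the fallback.\n",
    "\n",
    "```python\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Hypothetical stand-in for the composed config (not loaded from configs/).\n",
    "cfg = OmegaConf.create(\n",
    "    {\"model\": {\"_target_\": \"rl4co.models.AttentionModel\", \"optimizer_kwargs\": {\"lr\": 1e-4}}}\n",
    ")\n",
    "\n",
    "# Dot notation and dictionary-style indexing reach the same nested value.\n",
    "same = cfg.model.optimizer_kwargs.lr == cfg[\"model\"][\"optimizer_kwargs\"][\"lr\"]\n",
    "\n",
    "# OmegaConf.select takes a dotted path and can fall back to a default for absent keys.\n",
    "lr = OmegaConf.select(cfg, \"model.optimizer_kwargs.lr\")\n",
    "missing = OmegaConf.select(cfg, \"model.not_a_real_key\", default=\"n/a\")\n",
    "\n",
    "# Convert to built-in containers when another library expects a plain dict.\n",
    "plain = OmegaConf.to_container(cfg, resolve=True)\n",
    "print(same, lr, missing, type(plain).__name__)\n",
    "```\n",
    "\n",
    "`OmegaConf.select` can be useful when a key may be absent, for instance because an override removed it."
   ]
  },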
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want to change parts of the configuration, it is generally a good practice to make the changes via the command line when executing the respective python script (in the case of RL4CO for example [rl4co/tasks/train.py](../../../rl4co/tasks/train.py)). For example, if we want to use a different model configuration, we can do something like:\n",
    "\n",
    "```bash\n",
    "python train.py model=pomo model.batch_size=32\n",
    "```\n",
    "\n",
    "Here we use the model/pomo.yaml configuration for the model and also change the batch size during training to 32. \n",
    "\n",
    "> Note: check out the see [override syntax documentation](https://hydra.cc/docs/1.1/advanced/override_grammar/basic/) on the Hydra website for more!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate_default_data: true\n",
      "metrics:\n",
      "  train:\n",
      "  - loss\n",
      "  - reward\n",
      "  val:\n",
      "  - reward\n",
      "  - max_reward\n",
      "  - max_aug_reward\n",
      "  test: ${metrics.val}\n",
      "  log_on_step: true\n",
      "_target_: rl4co.models.POMO\n",
      "num_augment: 8\n",
      "batch_size: 32\n",
      "val_batch_size: 1024\n",
      "test_batch_size: 1024\n",
      "train_data_size: 1280000\n",
      "val_data_size: 10000\n",
      "test_data_size: 10000\n",
      "optimizer_kwargs:\n",
      "  lr: 0.0001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with initialize(version_base=None, config_path=ROOT_DIR+\"configs\"):\n",
    "    cfg = compose(config_name=\"main\", overrides=[\"model=pomo\",\"model.batch_size=32\"])\n",
    "    print(OmegaConf.to_yaml(cfg.model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also possible to add new parameters to a config using the `+` prefix. Using `++` will add a new parameter if it does not exist and _overwrite_ it if it does. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate_default_data: true\n",
      "metrics:\n",
      "  train:\n",
      "  - loss\n",
      "  - reward\n",
      "  val:\n",
      "  - reward\n",
      "  - max_reward\n",
      "  - max_aug_reward\n",
      "  test: ${metrics.val}\n",
      "  log_on_step: true\n",
      "_target_: rl4co.models.POMO\n",
      "num_augment: 8\n",
      "batch_size: 32\n",
      "val_batch_size: 1024\n",
      "test_batch_size: 1024\n",
      "train_data_size: 1280000\n",
      "val_data_size: 10000\n",
      "test_data_size: 10000\n",
      "optimizer_kwargs:\n",
      "  lr: 0.0001\n",
      "num_starts: 10\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with initialize(version_base=None, config_path=ROOT_DIR+\"configs\"):\n",
    "    cfg = compose(config_name=\"main\", overrides=[\"model=pomo\",\"model.batch_size=32\",\"+model.num_starts=10\"])\n",
    "    print(OmegaConf.to_yaml(cfg.model))"
   ]
  },
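  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between the override prefixes concrete, here is a toy sketch in plain Python that mimics them on a nested dictionary. It is not Hydra's parser: `apply_override` and `parse_value` are hypothetical helpers written for this illustration, value parsing is deliberately naive, and config-group selection such as `model=pomo` is not covered.\n",
    "\n",
    "```python\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def parse_value(text):\n",
    "    \"\"\"Naive value parsing for this sketch: int, then float, else string.\"\"\"\n",
    "    for cast in (int, float):\n",
    "        try:\n",
    "            return cast(text)\n",
    "        except ValueError:\n",
    "            pass\n",
    "    return text\n",
    "\n",
    "\n",
    "def apply_override(cfg, override):\n",
    "    \"\"\"Return a copy of `cfg` with one simplified override applied.\"\"\"\n",
    "    cfg = deepcopy(cfg)\n",
    "    if override.startswith(\"++\"):\n",
    "        mode, body = \"force_add\", override[2:]\n",
    "    elif override.startswith(\"+\"):\n",
    "        mode, body = \"add\", override[1:]\n",
    "    elif override.startswith(\"~\"):\n",
    "        mode, body = \"delete\", override[1:]\n",
    "    else:\n",
    "        mode, body = \"change\", override\n",
    "    key, _, raw = body.partition(\"=\")\n",
    "    *parents, leaf = key.split(\".\")\n",
    "    node = cfg\n",
    "    for name in parents:\n",
    "        node = node[name]  # parent keys must already exist in this sketch\n",
    "    if mode == \"delete\":\n",
    "        del node[leaf]  # raises KeyError if the key is absent\n",
    "    elif mode == \"change\" and leaf not in node:\n",
    "        raise KeyError(f\"{key} does not exist; use +{key} to add it\")\n",
    "    elif mode == \"add\" and leaf in node:\n",
    "        raise KeyError(f\"{key} already exists; use ++{key} to overwrite it\")\n",
    "    else:\n",
    "        node[leaf] = parse_value(raw)\n",
    "    return cfg\n",
    "\n",
    "\n",
    "cfg = {\"model\": {\"batch_size\": 512}}\n",
    "cfg = apply_override(cfg, \"model.batch_size=32\")    # change an existing value\n",
    "cfg = apply_override(cfg, \"+model.num_starts=10\")   # add a new key\n",
    "cfg = apply_override(cfg, \"++model.num_starts=20\")  # add or overwrite\n",
    "cfg = apply_override(cfg, \"~model.batch_size\")      # remove a key\n",
    "print(cfg)  # {'model': {'num_starts': 20}}\n",
    "```\n",
    "\n",
    "In this sketch, a plain `key=value` override on a missing key and a `+` override on an existing key both raise an error, which is why `++` is the safe choice when it is unclear whether a parameter is already defined."
   ]
  },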
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Likewise, we can also remove unwanted parts of the configuration. For example, if we do not want to use any experiment configuration, we can remove the changes to the configuration made by [experiments/base.yaml](../../../configs/experiment/base.yaml) using the `~` prefix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate_default_data: true\n",
      "metrics:\n",
      "  train:\n",
      "  - loss\n",
      "  - reward\n",
      "  val:\n",
      "  - reward\n",
      "  - max_reward\n",
      "  - max_aug_reward\n",
      "  test: ${metrics.val}\n",
      "  log_on_step: true\n",
      "_target_: rl4co.models.POMO\n",
      "num_augment: 8\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with initialize(version_base=None, config_path=ROOT_DIR+\"configs\"):\n",
    "    cfg = compose(config_name=\"main\", overrides=[\"model=pomo\",\"~experiment\"])\n",
    "    print(OmegaConf.to_yaml(cfg.model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, parameters like \"batch_size\" were removed from the model config, as those were set by the experiment config base.yaml. Through the hashbang\n",
    "```\n",
    "# @package _global_\n",
    "```\n",
    "in the [configs/experiments/base.yaml](../../../configs/experiment/base.yaml), this configuration is able to make changes to all parts of the configuration (like model, trainer, logger). So instead of adding a new key to the omegaconf object, configurations with a ```# @package _global_``` hashbang typically alter other parts of the configuration. \n",
    "\n",
    "Another example of such a configuration is the debug/default.yaml, which sets all parameters into a lightweight debugging mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate_default_data: true\n",
      "metrics:\n",
      "  train:\n",
      "  - loss\n",
      "  - reward\n",
      "  val:\n",
      "  - reward\n",
      "  test:\n",
      "  - reward\n",
      "  log_on_step: true\n",
      "_target_: rl4co.models.AttentionModel\n",
      "baseline: rollout\n",
      "batch_size: 8\n",
      "val_batch_size: 32\n",
      "test_batch_size: 32\n",
      "train_data_size: 64\n",
      "val_data_size: 1000\n",
      "test_data_size: 1000\n",
      "optimizer_kwargs:\n",
      "  lr: 0.0001\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with initialize(version_base=None, config_path=ROOT_DIR+\"configs\"):\n",
    "    cfg = compose(config_name=\"main\", overrides=[\"debug=default\"])\n",
    "    print(OmegaConf.to_yaml(cfg.model))"
   ]
  },
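  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the `# @package _global_` directive described above can be sketched in plain Python. This is a simplified illustration, not Hydra's composition logic: `deep_merge` and `place` are hypothetical helpers, the file contents are made-up stand-ins (including the `trainer` entry), and dotted package names are not handled.\n",
    "\n",
    "```python\n",
    "def deep_merge(base, update):\n",
    "    \"\"\"Return a new dict with `update` merged into `base`; `update` wins on conflicts.\"\"\"\n",
    "    merged = dict(base)\n",
    "    for key, value in update.items():\n",
    "        if isinstance(value, dict) and isinstance(merged.get(key), dict):\n",
    "            merged[key] = deep_merge(merged[key], value)\n",
    "        else:\n",
    "            merged[key] = value\n",
    "    return merged\n",
    "\n",
    "\n",
    "def place(content, package):\n",
    "    \"\"\"Nest `content` under `package`, or keep it at the root for `_global_`.\"\"\"\n",
    "    return content if package == \"_global_\" else {package: content}\n",
    "\n",
    "\n",
    "# Hypothetical file contents, loosely modelled on the outputs printed above.\n",
    "model_pomo = {\"_target_\": \"rl4co.models.POMO\", \"num_augment\": 8}\n",
    "experiment_base = {\"model\": {\"batch_size\": 512}, \"trainer\": {\"max_epochs\": 100}}\n",
    "\n",
    "cfg = {}\n",
    "cfg = deep_merge(cfg, place(model_pomo, \"model\"))          # lands under cfg[\"model\"]\n",
    "cfg = deep_merge(cfg, place(experiment_base, \"_global_\"))  # merged at the root\n",
    "print(sorted(cfg))  # ['model', 'trainer']\n",
    "```\n",
    "\n",
    "Note that no `experiment` key appears in the result: the global config only touches keys that belong to other parts of the configuration, which matches the key list shown at the start of this notebook."
   ]
  },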
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Reference config files using the CLI flag ```<key>=<config_file>``` (e.g. ```model=am```)\n",
    "- Add parameters (or even entire keys) to the config using the \"+\" prefix (e.g. ```+model.batch_size=32```)\n",
    "- Remove parameters (or even entire keys) to the config using the \"~\" prefix (e.g. ```~logger.wandb```)\n",
    "- The ```# @package _global_``` hashbang allows global access from any config file\n",
    "- Turn on debugging mode using ```debug=default```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl4co",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
